"""
# Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import json
import math
import queue
import threading
import time
import traceback

import numpy as np
import paddle

from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import (
    EngineWorkerQueue,
    IPCSignal,
    shared_memory_exists,
)
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
from fastdeploy.utils import envs, get_logger

# Module-wide logger shared by both messenger classes below; obtained from fastdeploy's
# get_logger (the second argument presumably names the log file — confirm in fastdeploy.utils).
logger = get_logger("cache_messager", "cache_messager.log")


def parse_args(argv=None):
    """
    Parse the command-line arguments of the cache messager process.

    Args:
        argv (list[str], optional): argument strings to parse. When ``None`` (the default),
            ``argparse`` falls back to ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        argparse.Namespace: the parsed arguments. ``speculative_config`` is decoded from its
        JSON string into a ``dict``; ``argparse`` also applies ``json.loads`` to the string
        default, so it is ``{}`` when the flag is omitted.
    """
    parser = argparse.ArgumentParser("Cache Messager")
    parser.add_argument(
        "--splitwise_role",
        type=str,
        default="mixed",
        help="splitwise role, can be decode, prefill or mixed",
    )
    parser.add_argument("--rank", type=int, default=0, help="current rank")
    parser.add_argument("--device_id", type=int, default=0, help="device id")
    parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
    parser.add_argument(
        "--protocol",
        type=str,
        default="ipc",
        # The messenger classes split this value on "," and accept "ipc" and "rdma".
        help="cache transfer protocol(s), comma separated, e.g. 'ipc' or 'ipc,rdma'",
    )
    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
    parser.add_argument("--cache_queue_port", type=int, default=9924, help="cache queue port")
    parser.add_argument(
        "--engine_worker_queue_port",
        type=int,
        default=9923,
        help="engine worker queue port",
    )
    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
    parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
    parser.add_argument(
        "--cache_dtype",
        type=str,
        default="bfloat16",
        choices=["uint8", "bfloat16"],
        help="cache dtype",
    )
    parser.add_argument(
        "--speculative_config",
        type=json.loads,
        default="{}",
        help="speculative config",
    )
    parser.add_argument("--local_data_parallel_id", type=int, default=0)

    return parser.parse_args(argv)


class CacheMessager:
    """
    CacheMessager is used to send the cache data between the engine worker and the cache server.

    On construction it collects the per-layer key/value cache tensors of this rank, creates one
    communication manager per requested transfer protocol ("ipc" and/or "rdma") and, for
    non-"mixed" roles, starts a daemon thread serving RDMA connect requests.
    ``prefill_layerwise_send_cache_thread`` is the (blocking) sender loop.
    """

    def __init__(
        self,
        splitwise_role,
        transfer_protocol,
        pod_ip,
        engine_worker_queue_port,
        local_data_parallel_id,
        gpu_cache_kvs,
        rank,
        nranks,
        num_layers,
        gpu_id=0,
        rdma_port=None,
    ):
        """
        Initialize the CacheMessager object.

        Args:
            splitwise_role (str): 'prefill', 'decode' or 'mixed'; the RDMA connect-task handler
                thread is started for every role except 'mixed'.
            transfer_protocol (str): comma separated protocols, each 'ipc' or 'rdma'.
            pod_ip (str): host of the engine worker queue.
            engine_worker_queue_port (int): engine_worker_queue port.
            local_data_parallel_id (int): local data-parallel id, forwarded to the queue and used
                to compute ``rank_id``.
            gpu_cache_kvs (dict): GPU kv cache, keyed by
                ``"key_caches_{layer}_rank{rank}_device{gpu_id}"`` and
                ``"value_caches_{layer}_rank{rank}_device{gpu_id}"``.
            rank (int): current rank.
            nranks (int): global rank number.
            num_layers (int): model layer number; must be at least 1.
            gpu_id (int, optional): GPU ID. Defaults to 0.
            rdma_port (optional): RDMA port passed to ``RDMACommManager``.

        Raises:
            ValueError: if ``num_layers`` is smaller than 1 (the block size is derived from the
                cache tensors, so at least one layer is required).
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.splitwise_role = splitwise_role
        self.gpu_cache_kvs = gpu_cache_kvs
        self.rank = rank
        self.nranks = nranks
        address = (pod_ip, engine_worker_queue_port)
        self.engine_worker_queue = EngineWorkerQueue(
            address=address,
            is_server=False,
            num_client=self.nranks,
            client_id=self.rank,
            local_data_parallel_id=local_data_parallel_id,
        )
        transfer_protocol = transfer_protocol.split(",")

        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}, rank: {rank}")

        # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
        self.num_layers = num_layers
        cache_k_ptr_list = []
        cache_v_ptr_list = []
        cache_k = []
        cache_v = []
        self.messager = {}
        for layer_idx in range(self.num_layers):
            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            cache_k.append(key_cache)
            cache_v.append(val_cache)
            cache_k_ptr_list.append(key_cache.data_ptr())
            cache_v_ptr_list.append(val_cache.data_ptr())
        cache_k_ptr_list = np.array(cache_k_ptr_list)
        cache_v_ptr_list = np.array(cache_v_ptr_list)

        # 2. initialize the block_bytes: dim 0 of the cache is the block index, the remaining
        # dims make up one block (element count; doubled for 2-byte bfloat16).
        cache_shape = key_cache.shape
        max_block_num = cache_shape[0]
        block_bytes = math.prod(cache_shape[1:])
        if key_cache.dtype == paddle.bfloat16:
            block_bytes *= 2
        logger.info(
            f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
        )
        self.block_bytes = block_bytes

        # 3. initialize the messager
        for protocol in transfer_protocol:
            if protocol == "ipc":
                self.messager[protocol] = IPCCommManager(
                    self.rank,
                    gpu_id,
                    cache_k,
                    cache_v,
                )
                local_device_id = self._device_id_from_place(cache_k[0].place)
                logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")

            elif protocol == "rdma":
                logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")

                self.messager[protocol] = RDMACommManager(
                    splitwise_role,
                    rank,
                    gpu_id,
                    cache_k_ptr_list,
                    cache_v_ptr_list,
                    max_block_num,
                    block_bytes,
                    rdma_port,
                )

        self.gpu_id = gpu_id
        self.cache_info = dict()
        self.rank_id = self.rank + local_data_parallel_id * self.nranks

        if self.splitwise_role != "mixed":
            connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
            connect_rdma_thread.daemon = True
            connect_rdma_thread.start()

        logger.info(f"cache messager init finished, use {transfer_protocol}")

    @staticmethod
    def _device_id_from_place(place):
        """
        Return the device index encoded in ``str(place)``.

        NOTE(review): assumes the text ends with ``:<index>)``, e.g. ``Place(gpu:12)``. The whole
        number after the last ``:`` is parsed, so multi-digit indexes such as 10+ are handled
        (taking only the character before ``)`` would read ``Place(gpu:12)`` as 2).
        """
        text = str(place).rstrip(")")
        return int(text.rsplit(":", 1)[-1])

    @staticmethod
    def _align_src_block_ids(src_block_ids, dest_block_ids):
        """
        Return the tail of ``src_block_ids`` that pairs up with ``dest_block_ids``.

        Only the last ``len(dest_block_ids)`` source blocks are transferred. Unlike
        ``src_block_ids[-len(dest_block_ids):]``, this yields an empty list when
        ``dest_block_ids`` is empty (``lst[-0:]`` is the *whole* list). When ``dest_block_ids``
        is longer, the whole source list is returned, as before.
        """
        offset = max(len(src_block_ids) - len(dest_block_ids), 0)
        return src_block_ids[offset:]

    def prefill_layerwise_send_cache_thread(self):
        """
        layerwise_send_cache_thread:
        send cache to other instance

        Runs forever and is meant to be used as a thread target. Each iteration:
          1. pulls new cache-transfer descriptors from the engine worker queue and merges the two
             descriptors of a request; once both block lists are known the request is "init";
          2. reads the prefill progress (step index, layer index) from the shared-memory signals;
          3. writes every newly computed layer of each ready request and, after the last layer,
             reports the request through ``put_finished_req`` (rank 0 only).

        Any exception ends the loop after being logged.
        """
        try:
            prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
            prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
            prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.rank_id}.{self.gpu_id}"
            prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
            # Attach to the signals, creating the shared memory only if it does not exist yet.
            step_shm_value = IPCSignal(
                name=f"splitwise_complete_prefilled_step_{self.rank_id}",
                array=prefilled_step_idx_data,
                dtype=np.int32,
                suffix=self.gpu_id,
                create=not shared_memory_exists(prefilled_step_name),
            )
            layer_shm_value = IPCSignal(
                name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
                array=prefilled_layer_idx_data,
                dtype=np.int32,
                suffix=self.gpu_id,
                create=not shared_memory_exists(prefilled_layer_name),
            )
            logger.info(f"splitwise_complete_prefilled_step_{self.rank_id}, gpu_id: {self.gpu_id}")

            # -1 marks "nothing prefilled yet"; the loop idles while the step index is -1.
            step_shm_value.value[0] = -1
            layer_shm_value.value[0] = -1

            self.last_step_idx = -1
            self.last_layer_idx = -1  # int32

            # NOTE(review): presumably the engine's step counter wraps at this value; every detected
            # wrap adds max_step_idx so remapped steps stay comparable with request "current_id"s.
            max_step_idx = 100003
            engine_recycled_count = 0

            while True:
                cache_info = self.engine_worker_queue.get_cache_info()

                if cache_info:
                    logger.debug(f"cache info {cache_info}")
                    for info in cache_info:
                        request_id = info["request_id"]
                        if request_id not in self.cache_info:
                            # First descriptor of this request: keep it until its counterpart arrives.
                            self.cache_info[request_id] = info
                            continue
                        current_info = self.cache_info[request_id]
                        current_info.update(info)
                        if "dest_block_ids" in current_info and "src_block_ids" in current_info:
                            current_info["src_block_ids"] = self._align_src_block_ids(
                                current_info["src_block_ids"], current_info["dest_block_ids"]
                            )
                            current_info["status"] = "init"
                            logger.info(f"start cache_infos: {current_info}")
                prefilled_layer_idx = layer_shm_value.value[0]
                prefilled_step_idx = step_shm_value.value[0]
                if prefilled_layer_idx == self.num_layers - 1:
                    # NOTE(review): presumably gives the producer a moment to publish the next
                    # step before both values are re-read together.
                    time.sleep(0.001)
                    prefilled_layer_idx = layer_shm_value.value[0]
                    prefilled_step_idx = step_shm_value.value[0]

                if prefilled_step_idx == -1:
                    time.sleep(0.001)
                    continue
                if not self.cache_info:
                    time.sleep(0.001)
                    continue

                if self.last_step_idx > prefilled_step_idx:
                    engine_recycled_count += 1
                self.last_step_idx = prefilled_step_idx  # only copy value read from shm memory
                prefilled_step_idx = (
                    prefilled_step_idx + max_step_idx * engine_recycled_count
                )  # remap prefilled_step_idx for comparison

                logger.debug(
                    f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
                    f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
                )
                for req_id, item in list(self.cache_info.items()):
                    if "status" not in item:
                        continue
                    if "layer_idx" not in item:
                        item["layer_idx"] = 0
                    if item["status"] == "error":
                        del self.cache_info[req_id]
                        continue
                    if item["current_id"] > prefilled_step_idx:
                        continue
                    current_transfer_protocol = item["transfer_protocol"]
                    if current_transfer_protocol == "rdma":
                        target_ip = item["ip"]
                        target_id = int(item["rdma_ports"][self.rank])
                        status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                        if not status:
                            logger.error(f"connect to {target_ip}:{target_id} failed")
                            item["status"] = "error"
                            self.engine_worker_queue.finish_request_barrier.wait()
                            if self.rank == 0:
                                self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
                            continue
                    elif current_transfer_protocol == "ipc":
                        target_ip = "0.0.0.0"
                        target_id = int(item["device_ids"][self.rank])
                    src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
                    dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
                    # A request from an earlier step is fully computed; for the current step only
                    # the layers up to the published layer index are ready.
                    if item["current_id"] < prefilled_step_idx:
                        current_layer_idx = self.num_layers
                    else:
                        current_layer_idx = prefilled_layer_idx + 1

                    for layer_idx in range(item["layer_idx"], current_layer_idx):
                        tic = time.time()
                        return_code = self.messager[current_transfer_protocol].write_cache(
                            target_ip,
                            target_id,
                            src_block_ids,
                            dest_block_ids,
                            layer_idx,
                        )
                        if return_code != 0:
                            item["status"] = "error"
                            self.engine_worker_queue.finish_request_barrier.wait()
                            if self.rank == 0:
                                self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
                            logger.error(
                                f"write cache failed, layer_idx: {layer_idx}, "
                                f"req_id: {item['request_id']}, dest_ip: {target_ip}"
                            )
                            break

                        cost_time = time.time() - tic
                        block_num = len(src_block_ids)
                        # Guard the metrics: an empty transfer or a sub-resolution timing must not
                        # raise ZeroDivisionError (which would end this thread).
                        if block_num > 0 and cost_time > 0:
                            avg_time_per_block = cost_time * 1000 / block_num  # ms
                            send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
                            logger.debug(
                                f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
                                f" {current_transfer_protocol}"
                                f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                                f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                            )
                    if item["status"] == "error":
                        # The failure was already reported as "write cache error" above; do not also
                        # report the request as finished (and wait on the barrier a second time).
                        del self.cache_info[req_id]
                        continue
                    item["layer_idx"] = current_layer_idx
                    if item["layer_idx"] == self.num_layers:
                        if current_transfer_protocol == "ipc":
                            self.messager["ipc"].write_block_by_sync(target_id)
                        logger.info(f"finish write cache {item['request_id']}")
                        self.engine_worker_queue.finish_request_barrier.wait()
                        if self.rank == 0:
                            # to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
                            self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
                            logger.info(f"put write cache {item['request_id']}")
                        del self.cache_info[req_id]
                self.last_layer_idx = prefilled_layer_idx

        except Exception as e:
            logger.error(f"prefill layerwise send cache thread has exception: {e}, {str(traceback.format_exc())}")

    def _handle_connect_task(self):
        """
        Serve RDMA connect requests from the engine worker queue forever.

        Each task carries ``task_id``, ``ip`` and ``rdma_port``. The outcome of
        ``self.messager["rdma"].connect`` is sent back as ``{"task_id": ..., "success": bool}``.
        Exceptions are logged and the loop keeps running.
        """
        while True:
            try:
                task = self.engine_worker_queue.get_connect_rdma_task()
                if task is None:
                    time.sleep(0.001)
                    continue
                logger.info(f"_handle_connect_task recv task: {task}")
                task_id = task["task_id"]
                ip, rdma_port = task["ip"], task["rdma_port"]
                status = self.messager["rdma"].connect(ip, rdma_port)
                response = {"task_id": task_id, "success": bool(status)}
                self.engine_worker_queue.put_connect_rdma_task_response(response)
            except Exception as e:
                logger.error(f"handle_connect_task has exception: {e}")
                # Back off so a persistently failing queue does not turn this into a busy loop.
                time.sleep(0.01)


class CacheMessagerV1:
    """
    CacheMessager is used to send the cache data between the engine worker and the cache server.
    """

    def __init__(
        self,
        splitwise_role,
        transfer_protocol,
        pod_ip,
        engine_worker_queue_port,
        local_data_parallel_id,
        gpu_cache_kvs,
        rank,
        nranks,
        num_layers,
        gpu_id=0,
        block_size=64,
        rdma_port=None,
    ):
        """
        Initialize the CacheMessager object.

        Args:
            splitwise_role (str): 'prefill', 'decode' or 'mixed'. 'prefill' starts the signal
                consumer and cache-task threads; every role except 'mixed' starts the RDMA
                connect-task handler thread.
            transfer_protocol (str): comma separated protocols, each 'ipc' or 'rdma'.
            pod_ip (str): host of the engine worker queue.
            engine_worker_queue_port (int): engine_worker_queue port.
            local_data_parallel_id (int): local data-parallel id, forwarded to the queue and used
                to compute ``rank_id``.
            gpu_cache_kvs (dict): GPU kv cache, keyed by
                ``"key_caches_{layer}_rank{rank}_device{gpu_id}"`` and
                ``"value_caches_{layer}_rank{rank}_device{gpu_id}"``.
            rank (int): current rank.
            nranks (int): global rank number.
            num_layers (int): model layer number; must be at least 1.
            gpu_id (int, optional): GPU ID. Defaults to 0.
            block_size (int, optional): tokens per cache block. Defaults to 64.
            rdma_port (optional): RDMA port passed to ``RDMACommManager``.

        Raises:
            ValueError: if ``num_layers`` is smaller than 1 (the block size is derived from the
                cache tensors, so at least one layer is required).
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.splitwise_role = splitwise_role
        self.gpu_cache_kvs = gpu_cache_kvs
        self.rank = rank
        self.nranks = nranks
        address = (pod_ip, engine_worker_queue_port)
        self.engine_worker_queue = EngineWorkerQueue(
            address=address,
            is_server=False,
            num_client=self.nranks,
            client_id=self.rank,
            local_data_parallel_id=local_data_parallel_id,
        )
        self.block_size = block_size
        transfer_protocol = transfer_protocol.split(",")

        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}, rank: {rank}")

        # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
        self.num_layers = num_layers
        cache_k_ptr_list = []
        cache_v_ptr_list = []
        cache_k = []
        cache_v = []
        self.messager = {}
        for layer_idx in range(self.num_layers):
            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            cache_k.append(key_cache)
            cache_v.append(val_cache)
            cache_k_ptr_list.append(key_cache.data_ptr())
            cache_v_ptr_list.append(val_cache.data_ptr())
        cache_k_ptr_list = np.array(cache_k_ptr_list)
        cache_v_ptr_list = np.array(cache_v_ptr_list)

        # 2. initialize the block_bytes: dim 0 of the cache is the block index, the remaining
        # dims make up one block (element count; doubled for 2-byte bfloat16).
        cache_shape = key_cache.shape
        max_block_num = cache_shape[0]
        block_bytes = math.prod(cache_shape[1:])
        if key_cache.dtype == paddle.bfloat16:
            block_bytes *= 2
        logger.info(
            f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
        )
        self.block_bytes = block_bytes

        # 3. initialize the messager
        for protocol in transfer_protocol:
            if protocol == "ipc":
                self.messager[protocol] = IPCCommManager(
                    self.rank,
                    gpu_id,
                    cache_k,
                    cache_v,
                )
                # NOTE(review): assumes str(place) ends with ":<index>)", e.g. "Place(gpu:12)".
                # Parse the whole number so device indexes >= 10 are reported correctly.
                local_device_id = int(str(cache_k[0].place).rstrip(")").rsplit(":", 1)[-1])
                logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")

            elif protocol == "rdma":
                logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")

                self.messager[protocol] = RDMACommManager(
                    splitwise_role,
                    rank,
                    gpu_id,
                    cache_k_ptr_list,
                    cache_v_ptr_list,
                    max_block_num,
                    block_bytes,
                    rdma_port,
                )

        self.gpu_id = gpu_id
        self.cache_info = dict()
        self.rank_id = self.rank + local_data_parallel_id * self.nranks
        # Guards engine_cache_tasks / idx_cache_task_dict updates shared between threads.
        self.engine_cache_task_thread_lock = threading.Lock()
        # One progress dict per engine batch slot.
        # NOTE(review): 512 is presumably the engine's maximum batch size — confirm.
        self.engine_cache_tasks = [dict() for _ in range(512)]
        # Ready transfer tasks keyed by their engine slot index ("current_id").
        self.idx_cache_task_dict = {}
        self.cache_prefilled_engine_ids_queue = queue.Queue()  # keep batch slot index for each prefill step
        if splitwise_role == "prefill":
            consume_signals_thread = threading.Thread(target=self.consume_signals)
            consume_signals_thread.daemon = True
            consume_signals_thread.start()
            add_cache_task_thread = threading.Thread(target=self._add_cache_task_thread)
            add_cache_task_thread.daemon = True
            add_cache_task_thread.start()

        if self.splitwise_role != "mixed":
            connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
            connect_rdma_thread.daemon = True
            connect_rdma_thread.start()

        logger.info(f"cache messager init finished, use {transfer_protocol}")

    def _add_cache_task_thread(self):
        while True:
            try:
                cache_info = self.engine_worker_queue.get_cache_info()
                self.engine_worker_queue.finish_add_cache_task_barrier.wait()
                finished_add_cache_task_req_ids = []
                if cache_info:
                    for info in cache_info:
                        if info["request_id"] in self.cache_info:
                            self.cache_info[info["request_id"]].update(info)
                            current_info = self.cache_info[info["request_id"]]
                            assert "dest_block_ids" in current_info and "src_block_ids" in current_info
                            finished_add_cache_task_req_ids.append(info["request_id"])
                            decode_cached_block_num = len(current_info["src_block_ids"]) - len(
                                current_info["dest_block_ids"]
                            )
                            padding_decode_block_ids = [-1 for i in range(decode_cached_block_num)] + current_info[
                                "dest_block_ids"
                            ]
                            current_info["dest_block_ids"] = padding_decode_block_ids
                            current_info["decode_cached_tokens"] = decode_cached_block_num * self.block_size
                            current_info["sended_layer_id"] = -1
                            current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
                            current_info["status"] = "init"
                            logger.info(f"finish add cache task: {current_info}")
                            self.cache_info[info["request_id"]] = current_info
                            self.idx_cache_task_dict[current_info["current_id"]] = current_info
                        else:
                            self.cache_info[info["request_id"]] = info
                    if self.rank == 0 and finished_add_cache_task_req_ids:
                        self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
                else:
                    time.sleep(0.001)
            except Exception as e:
                logger.info(f"add cache task occured error: {e},  {traceback.format_exc()!s}.")

    def prefill_layerwise_send_cache_thread(self):
        """
        layerwise_send_cache_thread:
        send cache to other instance
        """
        while True:
            try:
                engine_indexes = self.cache_prefilled_engine_ids_queue.get()
                self.engine_worker_queue.finish_request_barrier.wait()
                block_start_end_list = []
                current_prefilled_token_num_list = []
                for engine_index in engine_indexes:
                    assert engine_index in self.idx_cache_task_dict
                    block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
                    prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
                    if (
                        prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
                    ):  # all chunks have been prefilled
                        block_id_end = len(self.idx_cache_task_dict[engine_index]["src_block_ids"])
                    else:
                        block_id_end = prefilled_token_num // self.block_size  # [block_id_start, block_id_end)
                    block_start_end_list.append((block_id_start, block_id_end))
                    current_prefilled_token_num_list.append(prefilled_token_num)
                while True:  # from layer0 to last layer
                    sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
                    start_layer_idx = sended_layer_idx + 1
                    with self.engine_cache_task_thread_lock:  # to check end_layer_idx
                        prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
                        if sended_layer_idx > prefilled_layer_idx:  # computation must in next chunk
                            logger.info(
                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
                            )
                            assert (
                                current_prefilled_token_num_list[0]
                                < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
                            ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
                            end_layer_idx = self.num_layers - 1  # [start_layer_idx, end_layer_idx)
                        else:
                            end_layer_idx = prefilled_layer_idx
                    if sended_layer_idx == prefilled_layer_idx:  # computation not in next layer
                        time.sleep(0.01)
                    for layer_idx in range(start_layer_idx, end_layer_idx + 1):
                        for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
                            engine_index = engine_indexes[i]
                            task = self.idx_cache_task_dict[engine_index]
                            req_id = task["request_id"]
                            if (
                                block_id_start >= block_id_end
                            ):  # no blocks need to transfer for this request in this chunk
                                task["sended_layer_id"] += 1
                                assert task["sended_layer_id"] == layer_idx
                                if task["sended_layer_id"] == self.num_layers - 1:
                                    task["sended_layer_id"] = -1
                                continue
                            else:
                                current_transfer_protocol = task["transfer_protocol"]
                                if task["transfer_protocol"] == "rdma":
                                    target_ip = task["ip"]
                                    target_id = int(task["rdma_ports"][self.rank])
                                    if task["status"] == "error":
                                        continue
                                    status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                                    if not status:
                                        logger.error(f"connect to {target_ip}:{target_id} failed")
                                        task["status"] = "connection error"
                                        continue
                                elif task["transfer_protocol"] == "ipc":
                                    target_ip = "0.0.0.0"
                                    target_id = int(task["device_ids"][self.rank])

                                src_block_ids = task["src_block_ids"][block_id_start:block_id_end]
                                dest_block_ids = task["dest_block_ids"][block_id_start:block_id_end]
                                src_block_ids = paddle.to_tensor(src_block_ids, dtype="int32", place="cpu")
                                dest_block_ids = paddle.to_tensor(dest_block_ids, dtype="int32", place="cpu")

                                logger.info(
                                    f"start write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id}, block_id_start {block_id_start} block_id_end {block_id_end}"
                                )
                                tic = time.time()
                                return_code = self.messager[current_transfer_protocol].write_cache(
                                    target_ip,
                                    target_id,
                                    src_block_ids,
                                    dest_block_ids,
                                    layer_idx,
                                )
                                if return_code != 0:
                                    task["status"] = "write cache error"
                                    logger.error(
                                        f"write cache failed, layer_idx: {layer_idx}, req_id: {req_id}, dest_ip: {target_ip}, block_id_start {block_id_start} block_id_end {block_id_end}"
                                    )
                                tok = time.time()
                                cost_time = tok - tic
                                block_num = len(src_block_ids)
                                avg_time_per_block = cost_time * 1000 / block_num  # ms
                                send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
                                logger.debug(
                                    f"finish write cache for a layer, {req_id}, {layer_idx}, {target_ip}, {target_id},"
                                    f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                                    f"avg_time per block(ms): {round(avg_time_per_block, 5)} block_id_start {block_id_start} block_id_end {block_id_end}"
                                )

                                task["sended_layer_id"] += 1
                                assert task["sended_layer_id"] == layer_idx
                                if task["sended_layer_id"] == self.num_layers - 1:
                                    self.idx_cache_task_dict[engine_index]["sended_block_num"] += (
                                        block_id_end - block_id_start
                                    )
                                    if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
                                        if task["status"] != "error":
                                            task["status"] = "finished"
                                            logger.info(
                                                f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
                                            )
                                    else:
                                        task["sended_layer_id"] = -1
                    if end_layer_idx == self.num_layers - 1:
                        with self.engine_cache_task_thread_lock:
                            for engine_idx in engine_indexes:
                                task = self.idx_cache_task_dict[engine_idx]
                                if task["status"] == "finished" or ("error" in task["status"]):
                                    target_id = int(task["rdma_ports"][self.rank])
                                    if task["transfer_protocol"] == "ipc":
                                        self.messager["ipc"].write_block_by_sync(target_id)
                                    if self.rank == 0:
                                        # to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
                                        self.engine_worker_queue.put_finished_req(
                                            [(task["request_id"], task["status"])]
                                        )
                                        logger.info(f"put write cache {task['request_id']}, status {task['status']}")
                                    self.engine_cache_tasks[task["current_id"]] = dict()
                                    del self.cache_info[task["request_id"]]
                                    del self.idx_cache_task_dict[task["current_id"]]
                        break
            except Exception as e:
                logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
                time.sleep(0.01)

    def consume_signals(self):
        """
        Poll the prefill KV-cache signal buffer forever and record per-request progress.

        ``get_output_kv_signal`` fills ``kv_signal_data``. As read below:
          - slot 0 is the number of tasks in the batch (-1 means "nothing yet");
          - slot 1 is the layer id;
          - task ``bi`` occupies three consecutive slots starting at ``3 * bi + 2``:
            engine index, chunk token offset and current sequence length.

        For every task, the matching ``self.engine_cache_tasks`` entry gets its
        ``prefilled_layer_idx`` and ``prefilled_token_num``
        (chunk offset + current sequence length) updated under
        ``self.engine_cache_task_thread_lock``. When the signal is for layer 0, the
        batch's engine indexes are pushed onto ``self.cache_prefilled_engine_ids_queue``.
        Signals that arrive while ``self.cache_info`` is empty are dropped.
        """
        paddle.device.set_device("cpu")
        kv_signal_data = paddle.full(shape=[512 * 3 + 2], fill_value=-1, dtype="int32")
        while True:
            try:
                get_output_kv_signal(kv_signal_data, self.rank_id, 0)  # wait_flag
                if not self.cache_info:
                    # No pending cache tasks: ignore this signal.
                    time.sleep(0.01)
                    continue
                # Snapshot the whole buffer once as plain Python ints rather than
                # converting each indexed element tensor individually.
                signal = kv_signal_data.numpy().tolist()
                tasks_count = signal[0]
                if tasks_count == -1:
                    time.sleep(0.001)
                    continue
                layer_id = signal[1]
                if layer_id == self.num_layers - 1:
                    logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
                batch_engine_ids = []
                with self.engine_cache_task_thread_lock:
                    for bi in range(tasks_count):
                        base = 3 * bi + 2
                        engine_idx = signal[base]
                        chuck_token_offset = signal[base + 1]
                        current_seq_len = signal[base + 2]
                        self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
                        self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
                            chuck_token_offset + current_seq_len
                        )
                        batch_engine_ids.append(engine_idx)
                    if layer_id == 0:
                        self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
            except Exception as e:
                logger.error(f"Consume signals get exception: {e} {traceback.format_exc()!s}")
                # Back off briefly so a persistent failure does not spin the loop.
                time.sleep(0.01)

    def _handle_connect_task(self):
        """
        Serve RDMA connection requests from the engine worker queue forever.

        Each task fetched with ``get_connect_rdma_task`` is expected to be a dict
        holding ``task_id``, ``ip`` and ``rdma_port``. The connection is attempted via
        ``self.messager["rdma"].connect``. A response
        ``{"task_id": ..., "success": bool}`` is then posted with
        ``put_connect_rdma_task_response``. If the connect attempt itself raises, the
        requester still receives a ``success: False`` response rather than none.
        """
        while True:
            try:
                task = self.engine_worker_queue.get_connect_rdma_task()
                if task is None:
                    time.sleep(0.001)
                    continue
                logger.info(f"_handle_connect_task recv task: {task}")
                task_id = task["task_id"]
                ip, rdma_port = task["ip"], task["rdma_port"]
                try:
                    status = self.messager["rdma"].connect(ip, rdma_port)
                except Exception as e:
                    # Report the failure instead of leaving the requester without a reply.
                    logger.error(
                        f"handle_connect_task connect to {ip}:{rdma_port} failed: {e} {traceback.format_exc()!s}"
                    )
                    status = False
                response = {"task_id": task_id, "success": bool(status)}
                self.engine_worker_queue.put_connect_rdma_task_response(response)
            except Exception as e:
                logger.error(f"handle_connect_task has exception: {e} {traceback.format_exc()!s}")
                # Back off briefly so a persistent failure does not spin the loop.
                time.sleep(0.01)


def main():
    """
    Allocate this rank's KV-cache tensors, export them via IPC and run the cache messager.

    Configuration comes from the module-level ``args`` set in the ``__main__`` block.

    1. Allocate caches. There is one key cache and one value cache for each of
       ``args.num_layers`` model layers plus the extra layers requested by the
       speculative config. Each is a zero-filled tensor of shape
       ``[num_gpu_blocks, kv_num_head, block_size, head_dim]`` on ``gpu:{device_id}``,
       registered with ``set_data_ipc``.
    2. Build the messager: a ``CacheMessagerV1`` when ``ENABLE_V1_KVCACHE_SCHEDULER``
       is set, otherwise a ``CacheMessager``.
    3. Mark this rank as ready in the ``cache_ready_signal`` IPC signal.
    4. For the ``mixed`` role, idle forever. Otherwise run the prefill layer-wise
       send loop.
    """
    device = args.device_id
    rank = args.rank
    paddle.set_device(f"gpu:{device}")
    cache_type = args.cache_dtype
    speculative_config = SpeculativeConfig(args.speculative_config)
    num_extra_layers = speculative_config.num_extra_cache_layer
    num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
    total_layers = args.num_layers + num_extra_layers
    gpu_cache_kvs = {}

    for i in range(total_layers):
        # Extra (speculative) layers use their own block budget.
        num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
        cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
        # Allocate key then value so the dict keeps the per-layer key/value order.
        for kind in ("key", "value"):
            gpu_cache_kvs[f"{kind}_caches_{i}_rank{rank}_device{device}"] = paddle.full(
                shape=cache_shape,
                fill_value=0,
                dtype=cache_type,
            )
        # NOTE: the IPC name uses ".device" while the dict key uses "_device".
        for kind in ("key", "value"):
            set_data_ipc(
                gpu_cache_kvs[f"{kind}_caches_{i}_rank{rank}_device{device}"],
                f"{kind}_caches_{i}_rank{rank}.device{device}",
            )

    # Real byte size: element count times the per-element size of cache_type
    # (previously the element count was multiplied by a hard-coded 1).
    cache_kv_size_byte = sum(math.prod(t.shape) * t.element_size() for t in gpu_cache_kvs.values())
    logger.info(f"device :{device}")
    logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
    logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")

    messager_cls = CacheMessagerV1 if envs.ENABLE_V1_KVCACHE_SCHEDULER else CacheMessager
    cache_messager = messager_cls(
        splitwise_role=args.splitwise_role,
        transfer_protocol=args.protocol,
        pod_ip=args.pod_ip,
        engine_worker_queue_port=args.engine_worker_queue_port,
        local_data_parallel_id=args.local_data_parallel_id,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=rank,
        nranks=args.mp_num,
        num_layers=total_layers,
        gpu_id=device,
        rdma_port=args.rdma_port,
    )

    cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
    cache_ready_signal = IPCSignal(
        name="cache_ready_signal",
        array=cache_ready_signal_data,
        dtype=np.int32,
        suffix=args.engine_pid,
        create=False,
    )
    cache_ready_signal.value[rank] = 1
    if args.splitwise_role == "mixed":
        # Mixed role does not run the layer-wise send loop; just keep the process alive.
        while True:
            time.sleep(1)
    cache_messager.prefill_layerwise_send_cache_thread()


if __name__ == "__main__":
    # Parsed CLI options; main() reads `args` as a module global.
    args = parse_args()
    # Combined rank index: this process's rank offset by local_data_parallel_id * mp_num.
    rank_id = args.local_data_parallel_id * args.mp_num + args.rank
    # Rebind the module logger so each rank logs to its own file.
    logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
    logger.info("create cache messager...")
    logger.info(f"{args}")
    main()
